{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b323f441-93b8-4b61-b874-22d099a4df07",
   "metadata": {},
   "source": [
    "# OpenAI Whisper LoRA 模型评估和推理\n",
    "\n",
    "本教程将加载 LoRA 微调后的 Whisper 模型进行模型评估（基于 WER 指标），并将数据处理和模型准备流程也整合进来。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1b5a62f2-1079-4f06-8afc-9135db361d4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"openai/whisper-large-v2\"\n",
    "model_dir = \"../../models/whisper-large-v2-asr-int8\"\n",
    "\n",
    "language = \"Chinese (China)\"\n",
    "language_abbr = \"zh-CN\"\n",
    "task = \"transcribe\"\n",
    "dataset_name = \"mozilla-foundation/common_voice_11_0\"\n",
    "\n",
    "batch_size=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6c638814-d20b-4af9-b9c4-09f76ba1b631",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/david/anaconda3/envs/peft/lib/python3.10/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "/home/david/anaconda3/envs/peft/lib/python3.10/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "WhisperForConditionalGeneration(\n",
       "  (model): WhisperModel(\n",
       "    (encoder): WhisperEncoder(\n",
       "      (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
       "      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
       "      (embed_positions): Embedding(1500, 1280)\n",
       "      (layers): ModuleList(\n",
       "        (0-31): 32 x WhisperEncoderLayer(\n",
       "          (self_attn): WhisperSdpaAttention(\n",
       "            (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "            (v_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (q_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "          )\n",
       "          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "          (activation_fn): GELUActivation()\n",
       "          (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "          (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (decoder): WhisperDecoder(\n",
       "      (embed_tokens): Embedding(51865, 1280, padding_idx=50257)\n",
       "      (embed_positions): WhisperPositionalEmbedding(448, 1280)\n",
       "      (layers): ModuleList(\n",
       "        (0-31): 32 x WhisperDecoderLayer(\n",
       "          (self_attn): WhisperSdpaAttention(\n",
       "            (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "            (v_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (q_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "          )\n",
       "          (activation_fn): GELUActivation()\n",
       "          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "          (encoder_attn): WhisperSdpaAttention(\n",
       "            (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "            (v_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (q_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "            (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "          )\n",
       "          (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "          (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "          (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "          (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "  )\n",
       "  (proj_out): Linear(in_features=1280, out_features=51865, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoProcessor\n",
    "from peft import PeftConfig, PeftModel\n",
    "\n",
    "peft_config = PeftConfig.from_pretrained(model_dir)\n",
    "\n",
    "base_model = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
    "    peft_config.base_model_name_or_path, load_in_8bit=True, device_map=\"auto\"\n",
    ")\n",
    "base_model.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "24a1d9e4-e243-46e4-bff5-8d51ab302d21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModel(\n",
       "  (base_model): LoraModel(\n",
       "    (model): WhisperForConditionalGeneration(\n",
       "      (model): WhisperModel(\n",
       "        (encoder): WhisperEncoder(\n",
       "          (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
       "          (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
       "          (embed_positions): Embedding(1500, 1280)\n",
       "          (layers): ModuleList(\n",
       "            (0-31): 32 x WhisperEncoderLayer(\n",
       "              (self_attn): WhisperSdpaAttention(\n",
       "                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "              )\n",
       "              (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "              (activation_fn): GELUActivation()\n",
       "              (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "              (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "              (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (decoder): WhisperDecoder(\n",
       "          (embed_tokens): Embedding(51865, 1280, padding_idx=50257)\n",
       "          (embed_positions): WhisperPositionalEmbedding(448, 1280)\n",
       "          (layers): ModuleList(\n",
       "            (0-31): 32 x WhisperDecoderLayer(\n",
       "              (self_attn): WhisperSdpaAttention(\n",
       "                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "              )\n",
       "              (activation_fn): GELUActivation()\n",
       "              (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "              (encoder_attn): WhisperSdpaAttention(\n",
       "                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "              )\n",
       "              (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "              (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "              (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "              (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "      (proj_out): Linear(in_features=1280, out_features=51865, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "peft_model = PeftModel.from_pretrained(base_model, model_dir)\n",
    "peft_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cdb02d5f-07f5-4ef1-87cb-0c81925bf586",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
    "processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
    "feature_extractor = processor.feature_extractor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec285096-52a5-4e7a-9520-451c1b17c349",
   "metadata": {},
   "source": [
    "## 评估数据集处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9646f874-182c-4a6b-b000-b35f90fc893e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, DatasetDict, Audio\n",
    "\n",
    "common_voice = DatasetDict()\n",
    "common_voice[\"test\"] = load_dataset(dataset_name, language_abbr, split=\"test\", trust_remote_code=True)\n",
    "common_voice = common_voice.remove_columns(\n",
    "    [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"]\n",
    ")\n",
    "common_voice = common_voice.cast_column(\"audio\", Audio(sampling_rate=16000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2a0f0d3f-71d5-4cae-8230-70b1fb4548d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    audio = batch[\"audio\"]\n",
    "    batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
    "    batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "be3d5a3c-d498-4250-a590-ccd51e9b03f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_common_voice = DatasetDict()\n",
    "\n",
    "small_common_voice[\"test\"] = common_voice[\"test\"].shuffle(seed=16).select(range(320))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7c4e47ac-2001-47c0-a53a-9490c94bc784",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_common_voice = small_common_voice.map(prepare_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3875e35-9c9f-47d6-be11-ee553723179d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f2aa0d9e-8ca7-47ed-8856-cebce1d93c65",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "# 定义一个针对语音到文本任务的数据整理器类\n",
    "@dataclass\n",
    "class DataCollatorSpeechSeq2SeqWithPadding:\n",
    "    processor: Any  # 处理器结合了特征提取器和分词器\n",
    "\n",
    "    # 整理器函数，将特征列表处理成一个批次\n",
    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
    "        # 从特征列表中提取输入特征，并填充以使它们具有相同的形状\n",
    "        input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
    "        batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
    "\n",
    "        # 从特征列表中提取标签特征（文本令牌），并进行填充\n",
    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
    "        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
    "\n",
    "        # 使用-100替换标签中的填充区域，-100通常用于在损失计算中忽略填充令牌\n",
    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
    "\n",
    "        # 如果批次中的所有序列都以句子开始令牌开头，则移除它\n",
    "        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
    "            labels = labels[:, 1:]\n",
    "\n",
    "        # 将处理过的标签添加到批次中\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch  # 返回最终的批次，准备好进行训练或评估\n",
    "\n",
    "# 用给定的处理器实例化数据整理器\n",
    "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7eeb613-1dfb-45f0-828d-36150ed5b335",
   "metadata": {},
   "source": [
    "## 评估模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3e774b45-69a9-4f95-938d-e3284ff6060f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "# 词错误率（WER）是评估ASR模型常用的指标。从 Evaluate 加载 WER 指标\n",
    "metric = evaluate.load(\"wer\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6e5c0c9d-5ccd-4770-9cfc-78e51c956baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import gc\n",
    "\n",
    "eval_dataloader = DataLoader(tokenized_common_voice[\"test\"], batch_size=batch_size, collate_fn=data_collator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e1371372-b803-4cb0-96a2-110df3d52541",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/david/anaconda3/envs/peft/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
      "  0%|                                                                    | 0/20 [00:12<?, ?it/s]\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 118.00 MiB. GPU ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[14], line 9\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mamp\u001b[38;5;241m.\u001b[39mautocast():\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;66;03m# 不计算梯度，以节省计算资源，仅用于推理和评估\u001b[39;00m\n\u001b[1;32m      6\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m      7\u001b[0m         \u001b[38;5;66;03m# 生成预测的标记(tokens)，这里使用模型的generate函数进行文本生成\u001b[39;00m\n\u001b[1;32m      8\u001b[0m         generated_tokens \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m----> 9\u001b[0m             \u001b[43mpeft_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     10\u001b[0m \u001b[43m                \u001b[49m\u001b[43minput_features\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_features\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# 将输入特征移动到GPU上\u001b[39;49;00m\n\u001b[1;32m     11\u001b[0m \u001b[43m                \u001b[49m\u001b[43mdecoder_input_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlabels\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcuda\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# 提供解码器的初始输入\u001b[39;49;00m\n\u001b[1;32m     12\u001b[0m \u001b[43m                \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m255\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# 设置生成的最大新标记数量\u001b[39;49;00m\n\u001b[1;32m     13\u001b[0m \u001b[43m            \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     14\u001b[0m             \u001b[38;5;241m.\u001b[39mcpu()  \u001b[38;5;66;03m# 将生成的标记移回CPU\u001b[39;00m\n\u001b[1;32m     15\u001b[0m             \u001b[38;5;241m.\u001b[39mnumpy()  \u001b[38;5;66;03m# 转换为NumPy数组以便进一步处理\u001b[39;00m\n\u001b[1;32m     16\u001b[0m         )\n\u001b[1;32m     17\u001b[0m         \u001b[38;5;66;03m# 获取批次中的标签，并将其移回CPU\u001b[39;00m\n\u001b[1;32m     18\u001b[0m         labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:543\u001b[0m, in \u001b[0;36mWhisperGenerationMixin.generate\u001b[0;34m(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)\u001b[0m\n\u001b[1;32m    540\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m temperature \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    541\u001b[0m     kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtemperature\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m temperature\n\u001b[0;32m--> 543\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    544\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_features\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    545\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    546\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    547\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    548\u001b[0m \u001b[43m    \u001b[49m\u001b[43mprefix_allowed_tokens_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprefix_allowed_tokens_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    549\u001b[0m \u001b[43m    \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    550\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    551\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    553\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m generation_config\u001b[38;5;241m.\u001b[39mreturn_token_timestamps \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(generation_config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124malignment_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m    554\u001b[0m     outputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken_timestamps\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_extract_token_timestamps(\n\u001b[1;32m    555\u001b[0m         outputs, generation_config\u001b[38;5;241m.\u001b[39malignment_heads, num_frames\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_frames\n\u001b[1;32m    556\u001b[0m     )\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/transformers/generation/utils.py:1479\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m   1462\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39massisted_decoding(\n\u001b[1;32m   1463\u001b[0m         input_ids,\n\u001b[1;32m   1464\u001b[0m         candidate_generator\u001b[38;5;241m=\u001b[39mcandidate_generator,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1475\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m   1476\u001b[0m     )\n\u001b[1;32m   1477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m generation_mode \u001b[38;5;241m==\u001b[39m GenerationMode\u001b[38;5;241m.\u001b[39mGREEDY_SEARCH:\n\u001b[1;32m   1478\u001b[0m     \u001b[38;5;66;03m# 11. run greedy search\u001b[39;00m\n\u001b[0;32m-> 1479\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgreedy_search\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1480\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1481\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1482\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1483\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpad_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpad_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1484\u001b[0m \u001b[43m        \u001b[49m\u001b[43meos_token_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meos_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1485\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_scores\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput_scores\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1486\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreturn_dict_in_generate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1487\u001b[0m \u001b[43m        \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1488\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1489\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1490\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1492\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;241m==\u001b[39m GenerationMode\u001b[38;5;241m.\u001b[39mCONTRASTIVE_SEARCH:\n\u001b[1;32m   1493\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m model_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse_cache\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/transformers/generation/utils.py:2340\u001b[0m, in \u001b[0;36mGenerationMixin.greedy_search\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[1;32m   2337\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n\u001b[1;32m   2339\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 2340\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2341\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2342\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2343\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2344\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2345\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m synced_gpus \u001b[38;5;129;01mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m   2348\u001b[0m     \u001b[38;5;28;01mcontinue\u001b[39;00m  \u001b[38;5;66;03m# don't waste resources running the code we don't need\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/hooks.py:166\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    165\u001b[0m     output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_hf_hook\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpost_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/hooks.py:315\u001b[0m, in \u001b[0;36mAlignDevicesHook.post_forward\u001b[0;34m(self, module, output)\u001b[0m\n\u001b[1;32m    312\u001b[0m             module\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mCxB \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mio_same_device \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_device \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 315\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_device\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mskip_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:161\u001b[0m, in \u001b[0;36msend_to_device\u001b[0;34m(tensor, device, non_blocking, skip_keys)\u001b[0m\n\u001b[1;32m    158\u001b[0m     \u001b[38;5;28;01melif\u001b[39;00m skip_keys \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    159\u001b[0m         skip_keys \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    160\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(tensor)(\n\u001b[0;32m--> 161\u001b[0m         {\n\u001b[1;32m    162\u001b[0m             k: t \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m skip_keys \u001b[38;5;28;01melse\u001b[39;00m send_to_device(t, device, non_blocking\u001b[38;5;241m=\u001b[39mnon_blocking, skip_keys\u001b[38;5;241m=\u001b[39mskip_keys)\n\u001b[1;32m    163\u001b[0m             \u001b[38;5;28;01mfor\u001b[39;00m k, t \u001b[38;5;129;01min\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m    164\u001b[0m         }\n\u001b[1;32m    165\u001b[0m     )\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(tensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m    167\u001b[0m     \u001b[38;5;66;03m# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).\u001b[39;00m\n\u001b[1;32m    168\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_npu_available() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(device, \u001b[38;5;28mint\u001b[39m):\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:162\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    158\u001b[0m     \u001b[38;5;28;01melif\u001b[39;00m skip_keys \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    159\u001b[0m         skip_keys \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    160\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(tensor)(\n\u001b[1;32m    161\u001b[0m         {\n\u001b[0;32m--> 162\u001b[0m             k: t \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m skip_keys \u001b[38;5;28;01melse\u001b[39;00m \u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    163\u001b[0m             \u001b[38;5;28;01mfor\u001b[39;00m k, t \u001b[38;5;129;01min\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m    164\u001b[0m         }\n\u001b[1;32m    165\u001b[0m     )\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(tensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m    167\u001b[0m     \u001b[38;5;66;03m# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).\u001b[39;00m\n\u001b[1;32m    168\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_npu_available() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(device, \u001b[38;5;28mint\u001b[39m):\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:152\u001b[0m, in \u001b[0;36msend_to_device\u001b[0;34m(tensor, device, non_blocking, skip_keys)\u001b[0m\n\u001b[1;32m    139\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    140\u001b[0m \u001b[38;5;124;03mRecursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.\u001b[39;00m\n\u001b[1;32m    141\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;124;03m    The same data structure as `tensor` with all tensors sent to the proper device.\u001b[39;00m\n\u001b[1;32m    150\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[0;32m--> 152\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mhonor_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    153\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_keys\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    154\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, Mapping):\n\u001b[1;32m    156\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(skip_keys, \u001b[38;5;28mstr\u001b[39m):\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:84\u001b[0m, in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     82\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(obj)(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mlist\u001b[39m(generator))\n\u001b[1;32m     83\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 84\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgenerator\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:153\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    139\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    140\u001b[0m \u001b[38;5;124;03mRecursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.\u001b[39;00m\n\u001b[1;32m    141\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;124;03m    The same data structure as `tensor` with all tensors sent to the proper device.\u001b[39;00m\n\u001b[1;32m    150\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[1;32m    152\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[0;32m--> 153\u001b[0m         tensor, (\u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_keys\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m tensor)\n\u001b[1;32m    154\u001b[0m     )\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, Mapping):\n\u001b[1;32m    156\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(skip_keys, \u001b[38;5;28mstr\u001b[39m):\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:152\u001b[0m, in \u001b[0;36msend_to_device\u001b[0;34m(tensor, device, non_blocking, skip_keys)\u001b[0m\n\u001b[1;32m    139\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    140\u001b[0m \u001b[38;5;124;03mRecursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.\u001b[39;00m\n\u001b[1;32m    141\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;124;03m    The same data structure as `tensor` with all tensors sent to the proper device.\u001b[39;00m\n\u001b[1;32m    150\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[0;32m--> 152\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mhonor_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    153\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_keys\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    154\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, Mapping):\n\u001b[1;32m    156\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(skip_keys, \u001b[38;5;28mstr\u001b[39m):\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:84\u001b[0m, in \u001b[0;36mhonor_type\u001b[0;34m(obj, generator)\u001b[0m\n\u001b[1;32m     82\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(obj)(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mlist\u001b[39m(generator))\n\u001b[1;32m     83\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 84\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mtype\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgenerator\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:153\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    139\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    140\u001b[0m \u001b[38;5;124;03mRecursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.\u001b[39;00m\n\u001b[1;32m    141\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;124;03m    The same data structure as `tensor` with all tensors sent to the proper device.\u001b[39;00m\n\u001b[1;32m    150\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[1;32m    152\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[0;32m--> 153\u001b[0m         tensor, (\u001b[43msend_to_device\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_keys\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m tensor)\n\u001b[1;32m    154\u001b[0m     )\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor, Mapping):\n\u001b[1;32m    156\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(skip_keys, \u001b[38;5;28mstr\u001b[39m):\n",
      "File \u001b[0;32m~/anaconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:171\u001b[0m, in \u001b[0;36msend_to_device\u001b[0;34m(tensor, device, non_blocking, skip_keys)\u001b[0m\n\u001b[1;32m    169\u001b[0m     device \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnpu:\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdevice\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    170\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 171\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnon_blocking\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    172\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:  \u001b[38;5;66;03m# .to() doesn't accept non_blocking as kwarg\u001b[39;00m\n\u001b[1;32m    173\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mto(device)\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 118.00 MiB. GPU "
     ]
    }
   ],
   "source": [
    "# 遍历评估数据加载器中的所有批次\n",
    "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    # 使用自动混合精度来加速计算，并减少显存使用\n",
    "    with torch.cuda.amp.autocast():\n",
    "        # 不计算梯度，以节省计算资源，仅用于推理和评估\n",
    "        with torch.no_grad():\n",
    "            # 生成预测的标记(tokens)，这里使用模型的generate函数进行文本生成\n",
    "            generated_tokens = (\n",
    "                peft_model.generate(\n",
    "                    input_features=batch[\"input_features\"].to(\"cuda\"),  # 将输入特征移动到GPU上\n",
    "                    decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),  # 提供解码器的初始输入\n",
    "                    max_new_tokens=255,  # 设置生成的最大新标记数量\n",
    "                )\n",
    "                .cpu()  # 将生成的标记移回CPU\n",
    "                .numpy()  # 转换为NumPy数组以便进一步处理\n",
    "            )\n",
    "            # 获取批次中的标签，并将其移回CPU\n",
    "            labels = batch[\"labels\"].cpu().numpy()\n",
    "            # 将标签中的-100替换为填充标记的ID，-100通常用于忽略计算损失的标记\n",
    "            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "            # 使用分词器解码生成的标记和标签，以获得可读的文本\n",
    "            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
    "            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "            # 将预测和参考添加到评估指标中，用于后续的性能评估\n",
    "            metric.add_batch(\n",
    "                predictions=decoded_preds,\n",
    "                references=decoded_labels,\n",
    "            )\n",
    "    # 删除不再需要的变量以释放内存\n",
    "    del generated_tokens, labels, batch\n",
    "    # 手动触发垃圾收集，进一步清理内存\n",
    "    gc.collect()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41525e3b-3fe5-45cc-9e5c-32033960f424",
   "metadata": {},
   "source": [
    "### 使用全量数据微调后，对比 WER 指标降低了多少"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aa19ad4f-5f64-48e4-89f2-40ec2136868d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wed Mar  5 10:44:54 2025       \n",
      "+-----------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 570.86.15              Driver Version: 570.86.15      CUDA Version: 12.8     |\n",
      "|-----------------------------------------+------------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                        |               MIG M. |\n",
      "|=========================================+========================+======================|\n",
      "|   0  Tesla T4                       Off |   00000000:03:00.0 Off |                    0 |\n",
      "| N/A   42C    P0             24W /   70W |   14901MiB /  15360MiB |      0%      Default |\n",
      "|                                         |                        |                  N/A |\n",
      "+-----------------------------------------+------------------------+----------------------+\n",
      "|   1  Tesla T4                       Off |   00000000:81:00.0 Off |                    0 |\n",
      "| N/A   42C    P0             25W /   70W |    9176MiB /  15360MiB |      0%      Default |\n",
      "|                                         |                        |                  N/A |\n",
      "+-----------------------------------------+------------------------+----------------------+\n",
      "                                                                                         \n",
      "+-----------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                              |\n",
      "|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |\n",
      "|        ID   ID                                                               Usage      |\n",
      "|=========================================================================================|\n",
      "|    0   N/A  N/A            1913      G   /usr/lib/xorg/Xorg                      108MiB |\n",
      "|    0   N/A  N/A            2030      G   /usr/bin/gnome-shell                      6MiB |\n",
      "|    0   N/A  N/A            5679      C   ...naconda3/envs/peft/bin/python      14782MiB |\n",
      "|    1   N/A  N/A            1913      G   /usr/lib/xorg/Xorg                        4MiB |\n",
      "|    1   N/A  N/A            5679      C   ...naconda3/envs/peft/bin/python       9168MiB |\n",
      "+-----------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi  ##彭老师，我试过很多次，只有几次删除了device_map, 使用DP才可以运行，但wer=7.5...比1还大。不清楚那里出了问题。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3a785b-4b88-43ae-bee2-5412c3961b4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 计算词错误率（WER）指标，并将结果转换为百分比形式\n",
    "wer = 100 * metric.compute()\n",
    "\n",
    "# 打印词错误率，f\"{wer=}\"是一种格式化字符串的简洁写法，它会展示变量名和值\n",
    "print(f\"{wer=}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e698dc9f-e54d-4faf-a7a4-db5cc7ce6d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install jiwer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87256898-9b49-4c42-88a3-6899717d3ebf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
